"""
#
# Flow stability for dynamic community detection https://arxiv.org/abs/2101.06131v2
#
# Copyright (C) 2021 Alexandre Bovet <alexandre.bovet@maths.ox.ac.uk>
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 3 of the License, or (at your option) any
# later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.


This script computes the transition probabilities between clusters used to draw Fig. 9 and S5


saves the results in '../paper_data/aps' as 'complenet_forw_tau_w3.650e+03_prob_flow_integ_{orig_clust_name}_clust.pickle'
"""


import sys
import os
PACKAGE_PARENT = '..'
SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__))))
sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT)))

import numpy as np

from multiprocessing import Pool

import pandas as pd

from TemporalNetwork import ContTempNetwork, set_to_zeroes

from SparseStochMat import inplace_csr_row_normalize, sparse_matmul

from scipy.sparse import eye

from collections import Counter

# NOTE(review): this unconditional raise stops execution when the file is run
# top to bottom. Presumably the `#%%` cells below are meant to be executed one
# at a time in an interactive session -- confirm before removing.
raise Exception

#%%
# number of worker processes given to every multiprocessing.Pool below
nproc = 6

# directory and file-name prefix of the saved per-interval transition matrices
# read by `load_trans_mat`
datadir = '../paper_data/aps/lapl_intertransmat'

net_name = 'aps_monthly'

# this is the author name disambiguation of the APS dataset
# available at Supplementary Material at https://doi.org/10.1126/science.aaf5239
df_authors = pd.read_csv('../data/aps/APS_authors.csv', index_col=0)

# prefix of the output file name
network = 'complenet'

# key into `multi_res`; `1/tau_w` is the key into the saved transition matrices
# (presumably a time scale in days, 3650 ~ 10 years -- TODO confirm)
tau_w = 3650.0

# if True, `load_trans_mat` reads the 'T_lin' matrices instead of 'T'
lin = False

# tolerance passed to `set_to_zeroes` in `load_trans_mat`
tol = 1e-9

net_file = os.path.join('../paper_data/aps','aps_monthly_lcc_net')

# only the listed attributes of the saved network are requested
net = ContTempNetwork.load(net_file, attributes_list = ['node_to_label_dict',
                          'events_table',
                          'times',
                          'time_grid',
                          'num_nodes',
                          'time_slices_bounds',
                          'time_slices_bounds_datetimes'])

num_nodes = net.num_nodes

# key used to index TM['T_lin'][1/tau_w] in `load_trans_mat` (only when `lin` is True)
t_s = 10
# inverse mapping of `net.node_to_label_dict`: label -> node id
label_to_node_dict = {l : n for n, l in net.node_to_label_dict.items()}

#%%

# list of journals in which each author published
df_auth_journal = pd.read_csv('../data/aps/df_author_journal.csv', index_col='author')

# journal -> set of node ids of the authors whose `top_non_prl` is that journal.
# Authors that are not in the network are skipped; rows with a missing
# `top_non_prl` are not part of any group (groupby drops NaN keys by default).
class_dict_journ = {journ: {label_to_node_dict[l]
                            for l in grp.index.tolist()
                            if l in label_to_node_dict}
                    for journ, grp in df_auth_journal.groupby('top_non_prl')}

# authors without a `top_non_prl` entry are collected under 'other'
class_dict_journ['other'] = {label_to_node_dict[l]
                             for l in df_auth_journal.loc[df_auth_journal.top_non_prl.isnull()].index.tolist()
                             if l in label_to_node_dict}

# node id -> journal (NaN for authors without an entry).
# `Series.iteritems` was removed in pandas 2.0; `Series.items` yields the same
# (index, value) pairs on every pandas version.
node_to_journal_dict = {label_to_node_dict[a]: c
                        for a, c in df_auth_journal.top_non_prl.items()
                        if a in label_to_node_dict}

# list of countries (in affiliations) for each author
df_auth_country = pd.read_csv('../data/aps/df_author_country.csv', index_col='author')
# Assign the filled column back to the frame: writing through the chained
# `df.top_country.loc[mask] = ...` does not update the frame when pandas
# Copy-on-Write is active.
df_auth_country['top_country'] = df_auth_country['top_country'].fillna('undefined')

# country -> set of node ids of the authors with that `top_country`
class_dict_country = {count: {label_to_node_dict[l]
                              for l in grp.index.tolist()
                              if l in label_to_node_dict}
                      for count, grp in df_auth_country.groupby('top_country')}

# node id -> country ('undefined' when the country was missing)
node_to_country_dict = {label_to_node_dict[a]: c
                        for a, c in df_auth_country.top_country.items()
                        if a in label_to_node_dict}

    
#%%
# partitions for all the time scales, keyed by tau_w
multi_res = pd.read_pickle(os.path.join('../paper_data/aps', 'multi_complenet_partitions_all_comps_active_nodes_sorted.pickle'))

# %% pick initial cluster

#cluster from forw 00s
# orig_clust_id = 3 #UK/Belgian cluster
# orig_clust = multi_res[tau_w]['partitions_forw'][0].cluster_list[orig_clust_id]
# orig_clust_name = 'UK-Finland-Belgium'


# orig_clust_id = 18 #USA/Hungary
# orig_clust = multi_res[tau_w]['partitions_forw'][0].cluster_list[orig_clust_id]
# orig_clust_name = 'USA-Hungary'

orig_clust_id = 1 #USA/Italy
orig_clust = multi_res[tau_w]['partitions_forw'][0].cluster_list[orig_clust_id]
orig_clust_name = 'USA-Italy'


# (author name, country, journal) for every node of the selected cluster.
# NOTE(review): `iloc` is positional, so this assumes the node labels are row
# positions in APS_authors.csv -- confirm.
names = [(df_authors.iloc[net.node_to_label_dict[n]]['name'],
          node_to_country_dict[n],
          node_to_journal_dict[n]) for n in orig_clust]
print('\n'.join(str(n) for n in names))

# country composition of the cluster; printed (like the counters of the later
# cells) instead of being computed and thrown away
print(Counter(country for _, country, _ in names).most_common())


# we are doing forw diffusion, reverse_time
reverse_time = True
dire = 'forw'
tau_w = 3650.0
#%% load T

    
def load_trans_mat(k_range, tau_w, lin, reverse_time, tol, verbose=False):
    """Chain the saved per-interval transition matrices listed in `k_range`.

    Starting from the identity, the matrix stored for each `k` (visited in
    the order given by `k_range`) is multiplied on the right of the running
    product. Every loaded matrix and every intermediate product goes through
    `set_to_zeroes(..., tol=tol)` and `inplace_csr_row_normalize`.

    Reads the module-level `datadir`, `net_name`, `num_nodes` and, when
    `lin` is True, `t_s`.

    Parameters
    ----------
    k_range : iterable of int
        Interval indices, in multiplication order.
    tau_w : float
        Used in the file name and, as `1/tau_w`, as key into the stored matrices.
    lin : bool
        Read the 'lin_trans_mat' files / 'T_lin' entries instead of
        'trans_mat' / 'T'.
    reverse_time : bool
        Read the files whose name carries the 'reversed_' tag.
    tol : float
        Tolerance forwarded to `set_to_zeroes`.
    verbose : bool
        Forwarded to `sparse_matmul`.

    Returns
    -------
    (T, integr_time)
        The product matrix and the sum over `k_range` of
        |_t_stop_laplacians - _t_start_laplacians|.

    Used to quickly start again an integral computation.
    """
    direction_tag = 'reversed_' if reverse_time else ''
    kind_tag = 'lin_trans_mat' if lin else 'trans_mat'

    # file name template; the interval index `k` is filled in for each step
    file_template = os.path.join(datadir, net_name + \
                     f'_tau_w{tau_w:.3e}' + '_int{k:06d}__' + \
                     direction_tag + kind_tag)

    product = eye(num_nodes, format='csr', dtype=np.float64)
    elapsed = 0

    for k in k_range:
        stored = ContTempNetwork.load_T(file_template.format(k=k))

        step = stored['T_lin'][1/tau_w][t_s] if lin else stored['T'][1/tau_w]

        # threshold and renormalise the single-interval matrix ...
        set_to_zeroes(step, tol=tol)
        inplace_csr_row_normalize(step)

        # ... and again the running product after each multiplication
        product = sparse_matmul(product, step.tocsr(), verbose=verbose, log_message='T')
        set_to_zeroes(product, tol=tol)
        inplace_csr_row_normalize(product)

        elapsed += abs(stored['_t_stop_laplacians']-stored['_t_start_laplacians'])

    return product, elapsed



#%% T forw 2010 to 2000

# interval indices 108..117, visited from the last to the first when time is reversed
k_range = range(118-1,108-1,-1) if reverse_time else range(108,118)

T_2010_to_2000, integr_time = load_trans_mat(k_range, tau_w, lin, reverse_time, tol)


#%%
dec = '00s'

# keep only the rows of the nodes belonging to the original cluster
orig_clust_rows = list(orig_clust)
T_2010_to_2000_clust = T_2010_to_2000[orig_clust_rows,:]

# partition whose clusters are the targets of the first step
part_90s_forw = multi_res[tau_w]['partitions_forw'][1]

def _worker(c):
    """Sum the entries of `T_2010_to_2000_clust` lying in the columns listed in `c`.

    The matrix is read from the module namespace; `c` is an iterable of
    column indices (here, the nodes of one cluster).
    """
    columns = list(c)
    return T_2010_to_2000_clust[:, columns].sum()

# one task per cluster of the 90s forward partition: total weight going from
# the rows of the original cluster to the columns of that cluster
with Pool(nproc) as pool:
    T_2010_to_2000_double_clust = pool.map(_worker, part_90s_forw.cluster_list,
                      chunksize=1)


# prob from orig clust to all the clusts of 90s forw
T_clust_clust_2010_to_2000 = np.array(T_2010_to_2000_double_clust)/sum(T_2010_to_2000_double_clust)

dec = '00s'
# pd.to_pickle(T_clust_clust_00s, f'../paper_data/aps/{network}_{dire}_flow_prob_tau_w{tau_w:.3e}_clustered_T_{dec}.pickle',
#               protocol=4)


#%% top clusts from 90s

dec = '90s'
# cluster indices ordered by decreasing probability
desc_order_90s = np.argsort(T_clust_clust_2010_to_2000)[::-1]
print(T_clust_clust_2010_to_2000[desc_order_90s][:10])

# how many clusters receive more than 4% of the probability
# NOTE(review): the name `num_5p` suggests a 5% cut, but the threshold applied
# is 0.04 -- confirm which one is intended
num_5p = (T_clust_clust_2010_to_2000 > 0.04).sum()
print(num_5p)

clusts_90s_forw_idx = desc_order_90s[:num_5p]

# the retained clusters themselves, most probable first
clusts_90s_forw = [part_90s_forw.cluster_list[idx] for idx in clusts_90s_forw_idx]


# journal and country composition of the most probable cluster
print(Counter([node_to_journal_dict[n] for n in clusts_90s_forw[0]]).most_common())
print(Counter([node_to_country_dict[n] for n in clusts_90s_forw[0]]).most_common())


dec = '90s'
# pd.to_pickle(clusts_90s_forw,os.path.join('../paper_data/aps',f'{network}_{dire}_top_clusts_tau_w{tau_w:.3e}_{dec}.pickle'),
#               protocol=4)
    
#%% T 90s

# interval indices 98..107, visited from the last to the first when time is reversed
k_range = range(108-1,98-1,-1) if reverse_time else range(98,108)

T_2000_to_1990, integr_time = load_trans_mat(k_range, tau_w, lin, reverse_time, tol)


#%%

dec = '90s'
def _worker(params):
    """Sum the entries of a matrix lying in a given set of columns.

    `params` is a pair `(c, T_clust)`: an iterable of column indices and the
    matrix to slice. Taking a single argument lets it be used with `Pool.map`.
    """
    cluster, rows = params
    columns = list(cluster)
    return rows[:, columns].sum()

    
# per retained 90s cluster: raw weights towards every 80s cluster, and the
# same weights normalised to sum to one
T_clust_T_2000_to_1990 = []
T_clust_clust_2000_to_1990 = []

part_80s_forw = multi_res[tau_w]['partitions_forw'][2]

for i,clust in enumerate(clusts_90s_forw):

    print(i)

    # rows of the nodes of the current 90s cluster
    T_clust = T_2000_to_1990[list(clust),:]
    tasks = [(c, T_clust) for c in part_80s_forw.cluster_list]

    with Pool(nproc) as pool:
        flows = pool.map(_worker, tasks)

    T_clust_T_2000_to_1990.append(flows)
    T_clust_clust_2000_to_1990.append(np.array(flows)/sum(flows))


dec = '90s'
# pd.to_pickle(T_clust_clust_90s, f'../paper_data/aps/{network}_{dire}_flow_prob_tau_w{tau_w:.3e}_clustered_T_{dec}.pickle',
#                   protocol=4)

#%% top clusts from 80s

# two-step probabilities: weights of the retained 90s clusters times the
# stacked per-cluster distributions over the 80s clusters
stacked_2000_to_1990 = np.vstack(T_clust_clust_2000_to_1990)
T_clust_2010_to_1990 = T_clust_clust_2010_to_2000[clusts_90s_forw_idx] @ stacked_2000_to_1990

dec = '80s'
# how many have more than 5% prob
# print(np.sort(T_clust_clust_90s[0])[::-1][:10])
# num_5p = (T_clust_clust_90s[0][np.argsort(T_clust_clust_90s[0])[::-1]]>0.05).sum()
# print(num_5p)

# cluster indices ordered by decreasing probability
desc_order_80s = np.argsort(T_clust_2010_to_1990)[::-1]
print(T_clust_2010_to_1990[desc_order_80s][:10])

# number of 80s clusters above the 0.04 threshold
# NOTE(review): the name says 5%, the threshold applied is 4% -- confirm
num_5p = (T_clust_2010_to_1990 > 0.04).sum()
print(num_5p)


clusts_80s_forw_idx = desc_order_80s[:num_5p]

# for each retained 90s cluster taken separately: indices of the 80s clusters
# it reaches with probability above 0.04, most probable first
clusts_80s_2000_to_1990_forw_idx = [np.argsort(p)[::-1][:(p > 0.04).sum()]
                                    for p in T_clust_clust_2000_to_1990]


# the retained 80s clusters themselves
clusts_80s_forw = [part_80s_forw.cluster_list[clusts_idx]
                   for clusts_idx in clusts_80s_forw_idx]

# print(Counter([node_to_journal_dict[n] for n in clusts_80s_forw[0][0]]).most_common(10))
# print(Counter([node_to_country_dict[n] for n in clusts_80s_forw[0][0]]).most_common())



dec = '80s'
# pd.to_pickle(clusts_80s_forw,os.path.join('../paper_data/aps',f'{network}_{dire}_top_clusts_tau_w{tau_w:.3e}_{dec}.pickle'),
#               protocol=4)
#%% T 80s

# interval indices 88..97, visited from the last to the first when time is reversed
k_range = range(98-1,88-1,-1) if reverse_time else range(88,98)

T_1990_to_1980, integr_time = load_trans_mat(k_range, tau_w, lin, reverse_time, tol)


#%%

dec = '80s'
def _worker(params):
    """Sum the entries of a matrix lying in a given set of columns.

    `params` is a pair `(c, T_clust)`: an iterable of column indices and the
    matrix to slice. Taking a single argument lets it be used with `Pool.map`.
    """
    cluster, rows = params
    columns = list(cluster)
    return rows[:, columns].sum()

    
# per retained 80s cluster: raw weights towards every 70s cluster, and the
# same weights normalised to sum to one
T_clust_1990_to_1980 = []
T_clust_clust_1990_to_1980 = []

part_70s_forw = multi_res[tau_w]['partitions_forw'][3]

for i,clust in enumerate(clusts_80s_forw):

    print(i)

    # rows of the nodes of the current 80s cluster
    T_clust = T_1990_to_1980[list(clust),:]
    tasks = [(c, T_clust) for c in part_70s_forw.cluster_list]

    with Pool(nproc) as pool:
        flows = pool.map(_worker, tasks)

    T_clust_1990_to_1980.append(flows)
    T_clust_clust_1990_to_1980.append(np.array(flows)/sum(flows))


dec = '80s'
# pd.to_pickle(T_clust_clust_80s, f'../paper_data/aps/{network}_{dire}_flow_prob_tau_w{tau_w:.3e}_clustered_T_{dec}.pickle',
#                   protocol=4)    

#%% top clusts from 70s
dec = '70s'

# three-step probabilities: weights of the retained 80s clusters times the
# stacked per-cluster distributions over the 70s clusters
stacked_1990_to_1980 = np.vstack(T_clust_clust_1990_to_1980)
T_clust_2010_to_1980 = T_clust_2010_to_1990[clusts_80s_forw_idx] @ stacked_1990_to_1980


# cluster indices ordered by decreasing probability
desc_order_70s = np.argsort(T_clust_2010_to_1980)[::-1]
print(T_clust_2010_to_1980[desc_order_70s][:10])

# number of 70s clusters above the 0.04 threshold
# NOTE(review): the name says 5%, the threshold applied is 4% -- confirm
num_5p = (T_clust_2010_to_1980 > 0.04).sum()
print(num_5p)


clusts_70s_forw_idx = desc_order_70s[:num_5p]


# for each retained 80s cluster taken separately: indices of the 70s clusters
# it reaches with probability above 0.04, most probable first
clusts_70s_1990_to_1980_forw_idx = [np.argsort(p)[::-1][:(p > 0.04).sum()]
                                    for p in T_clust_clust_1990_to_1980]


# the retained 70s clusters themselves
clusts_70s_forw = [part_70s_forw.cluster_list[clusts_idx]
                   for clusts_idx in clusts_70s_forw_idx]

# print(Counter([node_to_journal_dict[n] for n in clusts_70s_forw[0][0]]).most_common(10))
# print(Counter([node_to_country_dict[n] for n in clusts_70s_forw[0][0]]).most_common())



dec = '70s'
# pd.to_pickle(clusts_70s_forw,os.path.join('../paper_data/aps',f'{network}_{dire}_top_clusts_tau_w{tau_w:.3e}_{dec}.pickle'),
#               protocol=4)

#%%
# everything computed above, gathered under the keys expected in the output file
results = {
    'orig_clust': orig_clust,
    'orig_clust_id': orig_clust_id,
    'orig_clust_name': orig_clust_name,
    # normalised cluster-to-cluster probabilities of each single step
    'T_clust_clust_2010_to_2000': T_clust_clust_2010_to_2000,
    'T_clust_clust_2000_to_1990': T_clust_clust_2000_to_1990,
    'T_clust_clust_1990_to_1980': T_clust_clust_1990_to_1980,
    # probabilities chained from the original cluster
    'T_clust_2010_to_1990': T_clust_2010_to_1990,
    'T_clust_2010_to_1980': T_clust_2010_to_1980,
    # retained clusters and their indices in the corresponding partitions
    'clusts_90s_forw': clusts_90s_forw,
    'clusts_80s_forw': clusts_80s_forw,
    'clusts_70s_forw': clusts_70s_forw,
    'clusts_90s_forw_idx': clusts_90s_forw_idx,
    'clusts_80s_forw_idx': clusts_80s_forw_idx,
    'clusts_70s_forw_idx': clusts_70s_forw_idx,
    'clusts_80s_2000_to_1990_forw_idx': clusts_80s_2000_to_1990_forw_idx,
    'clusts_70s_1990_to_1980_forw_idx': clusts_70s_1990_to_1980_forw_idx,
}

out_file = os.path.join('../paper_data/aps',
                        f'{network}_{dire}_tau_w{tau_w:.3e}_prob_flow_integ_{orig_clust_name}_clust.pickle')

pd.to_pickle(results, out_file, protocol=4)
